#!/usr/bin/python

# TODO:
# ! add option for padding
# - fix occasionally missing page numbers
# - treat large h-whitespace as separator
# - handle overlapping candidates
# - use cc distance statistics instead of character scale
# - page frame detection
# - read and use text image segmentation mask
# - pick up stragglers
# ? laplacian as well

import pdb
from pylab import *
import argparse,glob,os,os.path
from scipy.ndimage import filters,interpolation,morphology,measurements
from scipy import stats
from scipy.misc import imsave
from scipy.ndimage.filters import gaussian_filter,uniform_filter,maximum_filter,minimum_filter
import ocrolib
from ocrolib import psegutils,morph,improc,sl,lineproc
import multiprocessing
from multiprocessing import Pool
from ocrolib.toplevel import *

# Command-line interface.  The parsed options are stored in the module-level
# `args` object, which the functions below read directly.
parser = argparse.ArgumentParser()
parser.add_argument('-z','--zoom',type=float,default=0.5,help='zoom for page background estimation, smaller=faster')

parser.add_argument('--show',type=float,default=0,help='show the final output')
parser.add_argument('--gray',action='store_true',help='output grayscale lines as well (%(default)s)')
parser.add_argument('--usefilename',action='store_true',help='use the input filename, instead of base + .bin.png (%(default)s)')
parser.add_argument('-q','--quiet',action='store_true',help='be less verbose (%(default)s)')

# limits
parser.add_argument('--minscale',type=float,default=12.0,help='minimum scale permitted (%(default)s)')
parser.add_argument('--maxlines',type=float,default=300,help='maximum # lines permitted (%(default)s)')

# scale parameters
parser.add_argument('--scale',type=float,default=0.0,help='the basic scale of the document (roughly, xheight) 0=automatic (%(default)s)')
parser.add_argument('--hscale',type=float,default=1.0,help='non-standard scaling of horizontal parameters (%(default)s)')
parser.add_argument('--vscale',type=float,default=1.0,help='non-standard scaling of vertical parameters (%(default)s)')
parser.add_argument('--debugcleaned',action="store_true",help="display the cleaned image used for the gradient computation")

# line parameters
parser.add_argument('--debuglines',action="store_true",help="display the line seeds and gradient maps")
parser.add_argument('--debugreadingorder',action="store_true",help="enable debugging in the reading order computation")
parser.add_argument('--threshold',type=float,default=0.2,help='baseline threshold (%(default)s)')
parser.add_argument('--noise',type=int,default=8,help="noise threshold for removing small components from lines (%(default)s)")
parser.add_argument('--usegauss',action='store_true',help='use gaussian instead of uniform (%(default)s)')

# column parameters
parser.add_argument('--debugseps',action="store_true",help="debug black column separator computation in detail")
parser.add_argument('--maxseps',type=int,default=2,help='maximum black column separators (%(default)s)')
parser.add_argument('--sepwiden',type=int,default=10,help='widen black separators (to account for warping) (%(default)s)')

# whitespace column separators
parser.add_argument('--debugcolseps',action="store_true",help="debug whitespace column separator computation in detail")
parser.add_argument('--debugcols',action="store_true",help="debug the results of whitespace column separator computation")
parser.add_argument('--maxcolseps',type=int,default=2,help='maximum # whitespace column separators (%(default)s)')
# These two limits are applied to the whitespace separator regions
# themselves (see compute_colseps_morph), not to the text columns.
parser.add_argument('--csmaxwidth',type=float,default=10,help='maximum width of a whitespace column separator (units=scale) (%(default)s)')
parser.add_argument('--csminheight',type=float,default=20,help='minimum height of a whitespace column separator (units=scale) (%(default)s)')

# wait for input after everything is done
parser.add_argument("--debugwait",action="store_true",help="wait for return to be pressed after each image")

parser.add_argument('-p','--pad',type=int,default=3,help='padding for extracted lines (%(default)s)')
parser.add_argument('-e','--expand',type=int,default=3,help='expand mask for grayscale extraction (%(default)s)')
parser.add_argument('-Q','--parallel',type=int,default=0,help='number of worker processes; values below 2 process files sequentially (%(default)s)')
parser.add_argument('files',nargs='+')
args = parser.parse_args()
args.files = ocrolib.glob_all(args.files)

# Interactive display forces sequential processing: --parallel is set to 1
# and interactive plotting with a gray colormap is switched on.
if args.show>0:
    args.parallel = 1
    ion(); gray()

# With several worker processes the per-step progress messages are suppressed.
if args.parallel>1:
    args.quiet = True

def B(a):
    """Return `a` as an unsigned-byte array.

    An array whose dtype is already 'B' is returned unchanged (same
    object, no copy); anything else is converted with array(a,'B').
    """
    already_bytes = (a.dtype==dtype('B'))
    return a if already_bytes else array(a,'B')

def imfigure(title,image):
    """Display `image` with a gray colormap in the figure named `title`.

    The short ginput call (one point, 0.1 second timeout) is used to give
    the GUI a chance to draw the figure before processing continues.
    """
    figure(title)
    gray()
    imshow(image)
    ginput(n=1,timeout=0.1)



################################################################
### Column finding.
###
### This attempts to find column separators, either as extended
### vertical black lines or extended vertical whitespace.
### It will work fairly well in simple cases, but for unusual
### documents, you need to tune the parameters.
################################################################

def compute_separators_morph(binary,scale):
    """Finds vertical black lines corresponding to column separators.

    The binary image is dilated, opened with a structuring element that is
    10*scale rows tall and one column wide, and eroded again.  The result
    is then filtered twice with morph.select_regions: first on sl.dim1
    (min=3, nbest=2*args.maxseps), then on sl.dim0 (min=20*scale,
    nbest=args.maxseps).  Reads args.sepwiden and args.maxseps.
    """
    widen = args.sepwiden
    vsize = int(max(5,scale/4))
    hsize = int(max(5,scale))+widen
    # thicken strokes so that slightly warped or broken rules stay connected
    thickened = morph.r_dilation(binary,(vsize,hsize))
    # only tall vertical structures survive the opening
    candidates = morph.rb_opening(thickened,(10*scale,1))
    candidates = morph.r_erosion(candidates,(vsize//2,widen))
    candidates = morph.select_regions(candidates,sl.dim1,min=3,nbest=2*args.maxseps)
    return morph.select_regions(candidates,sl.dim0,min=20*scale,nbest=args.maxseps)

def compute_colseps_morph(binary,scale,debug=0,maxseps=3,minheight=20,maxwidth=5):
    """Finds extended vertical whitespace corresponding to column separators.

    The box map of the page is closed with a (20*scale, scale) element and
    inverted, which leaves the whitespace regions.  Those regions are then
    filtered with morph.select_regions using args.csmaxwidth (on sl.dim1)
    and args.csminheight / args.maxcolseps (on sl.dim0), and finally eroded
    and dilated vertically.

    NOTE: the keyword parameters `debug`, `maxseps`, `minheight` and
    `maxwidth` are accepted for interface compatibility but are not used;
    the limits are taken from the command-line options in `args`.
    """
    boxmap = psegutils.compute_boxmap(binary,scale,dtype='B')
    # (An unused `bounds` image that used to be computed here has been
    # removed; it never contributed to the result.)
    cols = 1-morph.rb_closing(boxmap,(int(20*scale),int(scale)))
    # negating the measurement turns the `min` bound into an upper bound on sl.dim1
    cols = morph.select_regions(cols,lambda x:-sl.dim1(x),min=-args.csmaxwidth*scale)
    cols = morph.select_regions(cols,sl.dim0,min=args.csminheight*scale,nbest=args.maxcolseps)
    cols = morph.r_erosion(cols,(int(0.5+scale),0))
    cols = morph.r_dilation(cols,(int(0.5+scale),0),origin=(int(scale/2)-1,0))
    return cols

def compute_colseps(binary,scale):
    """Computes column separators either from vertical black lines or whitespace.

    Returns a pair (colseps, binary): the pixelwise maximum of the
    whitespace and black-line separator images, and the input image with
    the black separators removed from it.
    """
    blacklines = compute_separators_morph(binary,scale)
    if args.debugcols:
        imfigure("column separators",0.7*blacklines+0.3*binary)
    whitespace = compute_colseps_morph(binary,scale)
    if args.debugcols:
        imfigure("column whitespace separators",0.7*whitespace+0.3*binary)
    combined = maximum(whitespace,blacklines)
    # erase the black separator lines from the page image
    without_lines = minimum(binary,1-blacklines)
    return combined,without_lines



################################################################
### Text Line Finding.
###
### This identifies the tops and bottoms of text lines by
### computing gradients and performing some adaptive thresholding.
### Those components are then used as seeds for the text lines.
################################################################

def compute_gradmaps(binary,scale):
    """Computes gradient maps used to locate the bottoms and tops of text lines.

    The page is masked with its box map and filtered with a first-order
    derivative along axis 0 (rows).  With args.usegauss a single Gaussian
    filter is used; otherwise a narrower Gaussian is followed by a wide
    uniform filter.  Returns (bottom, top, boxmap), where `bottom` is built
    from the negative part of the gradient and `top` from the positive
    part, each passed through improc.norm_max.
    """
    boxmap = psegutils.compute_boxmap(binary,scale)
    cleaned = boxmap*binary
    if args.debugcleaned:
        figure("debug-cleaned")
        clf()
        title("cleaned")
        imshow(cleaned)
    image = 1.0*cleaned
    vsigma = args.vscale*0.3*scale
    if args.usegauss:
        # a single anisotropic Gaussian derivative
        grad = gaussian_filter(image,(vsigma,args.hscale*6*scale),order=(1,0))
    else:
        # non-Gaussian oriented filter: Gaussian derivative, then uniform smoothing
        grad = gaussian_filter(image,(max(4,vsigma),args.hscale*scale),order=(1,0))
        grad = uniform_filter(grad,(args.vscale,args.hscale*6*scale))
    falling = (grad<0)*(-grad)
    rising = (grad>0)*grad
    bottom = improc.norm_max(falling)
    top = improc.norm_max(rising)
    return bottom,top,boxmap

def compute_line_seeds(binary,bottom,top,colseps,scale):
    """Based on gradient maps, computes candidates for baselines
    and xheights.  Then, it marks the regions between the two
    as a line seed.

    `bottom` and `top` are the gradient maps from compute_gradmaps and
    `colseps` is the column separator image; seeds are suppressed wherever
    `colseps` is set.  Returns the seed image labeled with morph.label.
    Reads args.threshold, args.vscale and args.debuglines.
    """
    t = args.threshold
    vrange = int(args.vscale*scale)
    # Mark local vertical maxima of each gradient map that exceed a threshold.
    # NOTE(review): `t` enters each threshold twice, so the effective cutoff
    # is t*t times the maximum (t*t/2 for tops); kept as is so that existing
    # segmentation results do not change.
    bmarked = maximum_filter(bottom==maximum_filter(bottom,(vrange,0)),(2,2))
    bmarked *= ((bottom>t*amax(bottom)*t)*(1-colseps)).astype(bool)
    tmarked = maximum_filter(top==maximum_filter(top,(vrange,0)),(2,2))
    tmarked *= ((top>t*amax(top)*t/2)*(1-colseps)).astype(bool)
    tmarked = maximum_filter(tmarked,(1,20))
    seeds = zeros(binary.shape,'i')
    delta = max(3,int(scale/2))
    for x in range(bmarked.shape[1]):
        # all marks in this column, from the bottom of the page upwards;
        # each entry is (row, 1) for a bottom mark and (row, 0) for a top mark
        transitions = sorted([(y,1) for y in find(bmarked[:,x])]+[(y,0) for y in find(tmarked[:,x])])[::-1]
        transitions += [(0,0)]
        for (y0,s0),(y1,s1) in zip(transitions,transitions[1:]):
            if s0==0: continue
            # Clamp at row 0: a negative slice start would count from the end
            # of the column and leave the band empty for marks near the top.
            seeds[max(0,y0-delta):y0,x] = 1
            # fill up to the next mark above if it is a top mark that is close enough
            if s1==0 and (y0-y1)<5*scale: seeds[y1:y0,x] = 1
    seeds = maximum_filter(seeds,(1,int(1+scale)))
    seeds *= (1-colseps)
    if args.debuglines:
        figure("debug-lineseeds")
        ocrolib.showrgb(seeds,0.3*tmarked+0.7*bmarked,binary)
        ginput(1,0.1)
    seeds,_ = morph.label(seeds)
    return seeds



################################################################
### The complete line segmentation process.
################################################################

def remove_hlines(binary,scale,maxsize=10):
    """Removes connected components wider than maxsize*scale.

    Returns a new unsigned-byte image that is 1 wherever a remaining
    component is present and 0 elsewhere.
    """
    labels,_ = morph.label(binary)
    limit = maxsize*scale
    for label,bounds in enumerate(morph.find_objects(labels),1):
        if sl.width(bounds)<=limit: continue
        # `region` is a view into `labels`, so clearing it edits `labels` in place
        region = labels[bounds]
        region[region==label] = 0
    return array(labels!=0,'B')

def compute_segmentation(binary,scale):
    """Given a binary image, compute a complete segmentation into
    lines, computing both columns and text lines.

    Returns a label image in which each pixel of `binary` (after
    horizontal lines and black column separators have been removed)
    carries the label of the text line it was assigned to.
    """
    verbose = not args.quiet

    # horizontal black lines only interfere with the rest of the
    # page segmentation, so they are removed first
    page = remove_hlines(array(binary,'B'),scale)

    # column finding
    if verbose: print("computing column separators")
    colseps,page = compute_colseps(page,scale)

    # text line seeds
    if verbose: print("computing lines")
    bottom,top,boxmap = compute_gradmaps(page,scale)
    seeds = compute_line_seeds(page,bottom,top,colseps,scale)
    if args.debuglines:
        figure("seeds")
        ocrolib.showrgb(bottom,top,boxmap)

    # spread the text line seeds to all the remaining components
    if verbose: print("propagating labels")
    propagated = morph.propagate_labels(boxmap,seeds,conflict=0)
    if verbose: print("spreading labels")
    spread = morph.spread_labels(seeds,maxdist=scale)
    # pixels without a propagated label fall back to the spread label
    merged = where(propagated>0,propagated,spread*page)
    return merged*page



################################################################
### Processing each file.
################################################################

def process1(job):
    """Segments one page image into text lines and writes the results.

    `job` is a (filename, index) pair; the index is only used in the
    summary line printed at the end.  The binary page image is read from
    the file itself (with --usefilename) or from base+".bin.png".  The
    page segmentation is written to base+".pseg.png" and the individual
    line images into the directory `base`.

    The page is skipped (the function returns without writing anything)
    when the estimated scale is below --minscale or when more than
    --maxlines lines are found.

    Raises IOError when --gray is given but base+".nrm.png" does not exist.
    """
    fname,index = job
    base,_ = ocrolib.allsplitext(fname)
    outputdir = base

    if args.usefilename:
        binary = ocrolib.read_image_binary(fname)
    else:
        binary = ocrolib.read_image_binary(base+".bin.png")

    checktype(binary,ABINARY2)

    grayimage = None
    if args.gray:
        grayname = base+".nrm.png"
        # fail with a clear message instead of an unbound local variable
        if not os.path.exists(grayname):
            raise IOError("%s: --gray requires the grayscale image %s"%(fname,grayname))
        grayimage = ocrolib.read_image_gray(grayname)
        checktype(grayimage,GRAYSCALE)

    binary = 1-binary # invert

    if args.scale==0:
        scale = psegutils.estimate_scale(binary)
    else:
        scale = args.scale
    if scale<args.minscale:
        sys.stderr.write("%s: scale (%g) less than --minscale; skipping\n"%(fname,scale))
        return

    # find columns and text lines

    if not args.quiet: print("computing segmentation")
    segmentation = compute_segmentation(binary,scale)
    if amax(segmentation)>args.maxlines:
        print("%s : too many lines %s"%(fname,amax(segmentation)))
        return
    if not args.quiet: print("number of lines %s"%amax(segmentation))

    # compute the reading order

    if not args.quiet: print("finding reading order")
    lines = psegutils.compute_lines(segmentation,scale)
    order = psegutils.reading_order([l.bounds for l in lines],debug=args.debugreadingorder)
    lsort = psegutils.topsort(order)

    # renumber the labels so that they conform to the specs

    nlabels = amax(segmentation)+1
    renumber = zeros(nlabels,'i')
    for n,v in enumerate(lsort): renumber[lines[v].label] = 0x010000+(n+1)
    segmentation = renumber[segmentation]

    # finally, output everything

    if args.show:
        figure("output")
        clf(); title("output"); psegutils.show_lines(binary,lines,lsort)

    if not args.quiet: print("writing lines")
    if not os.path.exists(outputdir):
        os.mkdir(outputdir)
    lines = [lines[k] for k in lsort]
    ocrolib.write_page_segmentation("%s.pseg.png"%outputdir,segmentation)
    cleaned = improc.remove_noise(binary,args.noise)
    for n,l in enumerate(lines):
        binline = psegutils.extract_masked(1-cleaned,l,pad=args.pad,expand=args.expand)
        ocrolib.write_image_binary("%s/01%04x.bin.png"%(outputdir,n+1),binline)
        if args.gray:
            grayline = psegutils.extract_masked(grayimage,l,pad=args.pad,expand=args.expand)
            ocrolib.write_image_gray("%s/01%04x.nrm.png"%(outputdir,n+1),grayline)
    # summary line: job index, file name, scale, number of lines
    print("%6d %s %4.1f %d"%(index,fname,scale,len(lines)))
    if args.debugwait:
        ginput(1,0.1)
        print("hit return for next image")
        raw_input()

# A single directory argument is expanded to the files inside it whose
# names match "????.png".
if len(args.files)==1 and os.path.isdir(args.files[0]):
    files = glob.glob(args.files[0]+"/????.png")
else:
    files = args.files

# Each job is a (filename, 1-based index) pair, as expected by process1.
jobs = [(f,i+1) for i,f in enumerate(files)]

if args.parallel<2:
    # sequential processing; the file name is echoed only when --parallel is 0
    for job in jobs:
        if args.parallel==0: print(job[0])
        process1(job)
else:
    pool = Pool(processes=args.parallel)
    try:
        pool.map(process1,jobs)
    finally:
        # release the worker processes even if a job raised
        pool.close()
        pool.join()
